/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cutlass/arch/arch.h>

#include <cub/cub.cuh>
#include <cute/arch/cluster_sm90.hpp>
#include <type_traits>

#include "DevKernel.h"
#include "RoutingKernel.h"
#include "RoutingKernelTopK.cuh"

////////////////////////////////////////////////////////////////////////////////////////////////////
namespace moe::dev {

////////////////////////////////////////////////////////////////////////////////////////////////////
namespace routing {

namespace cg = cooperative_groups;

////////////////////////////////////////////////////////////////////////////////////////////////////

static constexpr int WarpSize = 32;
static constexpr int NumBlocksPerCluster = 8;
// Performance tuning knob.
static constexpr int NumEltsPerOffsetTilePerThread = 8;

////////////////////////////////////////////////////////////////////////////////////////////////////

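// Sigmoid computed via the identity sigmoid(x) == 0.5 * tanh(x / 2) + 0.5.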
static __device__ inline float sigmoid_accurate(float x) { return 0.5f * tanhf(0.5f * x) + 0.5f; }

////////////////////////////////////////////////////////////////////////////////////////////////////

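// Helpers for arithmetic with power-of-two tile sizes passed as log2(tileSize):
//   mulLog2(a, b)      == a * 2^b
//   divUpLog2(a, b)    == ceil(a / 2^b),       e.g. divUpLog2(13, 3) == 2
//   divUpMulLog2(a, b) == ceil(a / 2^b) * 2^b, i.e. a rounded up to a multiple of 2^b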
template <typename T>
__host__ __device__ constexpr T mulLog2(T a, T bLog2) {
  return a << bLog2;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
__host__ __device__ constexpr T divUpLog2(T a, T bLog2) {
  return ((a + (1 << bLog2) - 1) >> bLog2);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
__host__ __device__ constexpr T divUpMulLog2(T a, T bLog2) {
  return mulLog2<T>(divUpLog2<T>(a, bLog2), bLog2);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
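// Same helpers for arbitrary (not necessarily power-of-two) tile sizes passed as tileN directly,
// e.g. mulTileN(13, 8) == 104, divUpTileN(13, 8) == 2, divUpMulTileN(13, 8) == 16.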
template <typename T>
__host__ __device__ constexpr T mulTileN(T a, T tileN) {
  return a * tileN;
}

////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
__host__ __device__ constexpr T divUpTileN(T a, T tileN) {
  return (a + tileN - 1) / tileN;
}

////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
__host__ __device__ constexpr T divUpMulTileN(T a, T tileN) {
  return divUpTileN(a, tileN) * tileN;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

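// getBits/setBits treat an int32_t as four packed 8-bit fields: getBits extracts byte `idx`
// (0 == least significant), setBits overwrites it, and setBits<true> skips the clearing step
// when the target byte is already known to be zero.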
__host__ __device__ constexpr int32_t getBits(int32_t value, int idx) {
  int mask = idx == 0 ? 0x000000FF : idx == 1 ? 0x0000FF00 : idx == 2 ? 0x00FF0000 : 0xFF000000;
  return (value & mask) >> (idx * 8);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool IsZero = false>
__host__ __device__ constexpr void setBits(int32_t& value, int32_t newBits, int idx) {
  if constexpr (!IsZero) {
    int mask = idx == 0 ? 0xFFFFFF00 : idx == 1 ? 0xFFFF00FF : idx == 2 ? 0xFF00FFFF : 0x00FFFFFF;
    value &= mask;
  }
  value |= (newBits << (idx * 8));
}

////////////////////////////////////////////////////////////////////////////////////////////////////
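// Strided fill: writes `value` to arr[startIdx], arr[startIdx + stride], ... up to numElts;
// a null `arr` makes this a no-op.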
template <typename DataType>
__device__ void initArr(int startIdx, int numElts, int stride, DataType* arr, DataType value) {
  if (arr != nullptr) {
    for (int i = startIdx; i < numElts; i += stride) {
      arr[i] = value;
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////
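// Warp-cooperative softmax: each lane holds VecSize scores and the max/sum reductions run
// across the whole warp, so the warp collectively normalizes one token's scores in place.
// Typical call (illustrative):
//   calcSoftmax(cg::tiled_partition<WarpSize>(cg::this_thread_block()), scores);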
template <typename DataType, int VecSize>
__device__ void calcSoftmax(cg::thread_block_tile<WarpSize> const& warp,
                            DataType (&scores)[VecSize]) {
  // Compute in float to support half/bfloat16 inputs safely.
  float maxScore = -INFINITY;
  float sumScore = 0.f;
  // Get the max score for each token
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    float si = static_cast<float>(scores[i]);
    maxScore = si >= maxScore ? si : maxScore;
  }
  maxScore = cg::reduce(warp, maxScore, cg::greater<float>());

  // Get the summation of scores for each token
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    float si = static_cast<float>(scores[i]);
    float e = expf(si - maxScore);
    scores[i] = static_cast<DataType>(e);
    sumScore += e;
  }
  sumScore = cg::reduce(warp, sumScore, cg::plus<float>());

  // Normalize the scores
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    float si = static_cast<float>(scores[i]) / sumScore;
    scores[i] = static_cast<DataType>(si);
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

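// Scalar variant: each of the first NumTopExperts lanes holds one score; the remaining lanes
// contribute the identity (-inf / 0) to the reductions and return their input unchanged.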
template <typename DataType>
__device__ DataType calcSoftmax(cg::thread_block_tile<WarpSize> const& warp, DataType score,
                                int32_t laneIdx, int32_t NumTopExperts) {
  DataType maxScore = DataType{-INFINITY};
  if (laneIdx < NumTopExperts) {
    maxScore = score >= maxScore ? score : maxScore;
  }
  maxScore = cg::reduce(warp, maxScore, cg::greater<DataType>());

  float sumScore = float{0.f};
  float newScore;
  // Get the summation of scores for each token
  if (laneIdx < NumTopExperts) {
    newScore = static_cast<float>(score) - static_cast<float>(maxScore);
    newScore = static_cast<float>(exp(newScore));
    sumScore += newScore;
  }
  sumScore = cg::reduce(warp, sumScore, cg::plus<float>());

  if (laneIdx < NumTopExperts) {
    score = static_cast<DataType>(newScore / sumScore);
  }

  return score;
}

////////////////////////////////////////////////////////////////////////////////////////////////////
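// Cluster-wide permutation step: gathers each token's top-k (score, expert) pairs, builds
// per-expert token counts and offsets across the cluster, and writes the CTA-to-batch /
// mn-limit tables plus the expanded-to-permuted and permuted-to-token index mappings consumed
// by the downstream projection kernels.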
template <typename KernelParams, typename BaseType, int NumThreads, int NumWarps,
          int MaxNumTopExperts, bool LoadExpertIdxFromGlobal = false>
__device__ void routingPermutation(KernelParams params,
                                   PackedScoreIdx<BaseType>* smemPackedScoreIdx,
                                   int32_t const warpIdx, uint32_t const clusterBlockRank) {
  using OutputT = typename KernelParams::OutputT;
  using TypePacked = PackedScoreIdx<BaseType>;

  static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads;
  // Number of threads in the cluster.
  static constexpr int NumThreadsPerCluster = NumThreads * NumBlocksPerCluster;
  // Max number of expanded indices (token, top-k slot pairs) handled per thread.
  // Note: NumThreadsPerCluster == MaxNumTokensSingleCluster, so this equals MaxNumTopExperts.
  static constexpr int MaxExpandedIdxPerThread =
      (MaxNumTokensSingleCluster * MaxNumTopExperts + NumThreadsPerCluster - 1) /
      NumThreadsPerCluster;

  // Needed for the exclusive sum of token offsets.
  // Note: the scan might include more bins than needed, with bin counts of 0 to pad
  using Scan = cub::BlockScan<int32_t, NumThreads, cub::BLOCK_SCAN_WARP_SCANS>;
  __shared__ typename Scan::TempStorage tempStorage;

  uint32_t const clusterThreadIdx = NumThreads * clusterBlockRank + threadIdx.x;
  auto expandedIdxSize = params.mNumTokens * params.mTopK;

  // number of experts is bounded by number of threads
  __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads];
  __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads];

  // pre-fill the counts with 0
  if (threadIdx.x < params.mNumExperts) {
    smemExpertCount[threadIdx.x] = 0;
  }
  __syncthreads();

  // each thread keeps some number of "expanded indexes" assigned to it
  // note that an expanded index identifies a (token, top-k choice) pair.
  // for each of these, we keep the associated expert and offset within expert in registers
  int32_t expertIndexes[MaxExpandedIdxPerThread];
  int32_t expertOffsets[MaxExpandedIdxPerThread];
  auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;

  // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a
  // time, and branch between a fast path without bound checks and a slow path with bound checks.
  // TODO(mjoux): potentially add this back for perf tuning
  // int constexpr IterStride = 4;
  // static_assert(MaxExpandedIdxPerThread % IterStride == 0);

  // Define a lambda to avoid code duplication in both branches.
  auto loopBody = [&](int ii, int expandedIdx) {
    TypePacked scoreIdx;
    if constexpr (LoadExpertIdxFromGlobal) {
      if (params.mPtrTopKIds != nullptr) {
        scoreIdx = TypePacked{static_cast<BaseType>(params.mPtrTopKWeights[expandedIdx]),
                              static_cast<int16_t>(params.mPtrTopKIds[expandedIdx])};
      } else {
        scoreIdx = TypePacked{static_cast<BaseType>(params.mPtrTopKPacked[expandedIdx].score),
                              static_cast<int16_t>(params.mPtrTopKPacked[expandedIdx].idx)};
      }
    } else {
      TypePacked const* remoteSmem = cg::cluster_group::map_shared_rank(
          smemPackedScoreIdx, expandedIdx / (NumWarps * params.mTopK));
      scoreIdx = remoteSmem[expandedIdx % (NumWarps * params.mTopK)];
    }

    expertIndexes[ii] = scoreIdx.idx;
    // check whether this expert is local to our GPU at all and ignore if not
    auto localExpertIdx = scoreIdx.idx - params.mLocalExpertsStartIdx;
    auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent &&
                         (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
    expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + scoreIdx.idx, 1) : 0;
    if (params.mPtrTopKWeights != nullptr && params.mPtrTopKIds == nullptr) {
      params.mPtrTopKWeights[expandedIdx] = OutputT{scoreIdx.score};
    }
  };

  int constexpr IterStride = 4;
#pragma unroll
  for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride) {
    // Whether it's safe to do multiple iterations without bound checks.
    bool const takeFastPath = (ii0 + IterStride) * NumThreadsPerCluster <= expandedIdxSize;
    if (takeFastPath) {
#pragma unroll
      for (int32_t jj = 0; jj < IterStride; jj++) {
        int const ii = ii0 + jj;
        auto expandedIdx = static_cast<int32_t>(clusterThreadIdx) + ii * NumThreadsPerCluster;
        loopBody(ii, expandedIdx);
      }
    } else {
      bool doBreak = false;
#pragma unroll
      for (int32_t jj = 0; jj < IterStride; jj++) {
        int const ii = ii0 + jj;
        auto expandedIdx = static_cast<int32_t>(clusterThreadIdx) + ii * NumThreadsPerCluster;
        if (expandedIdx >= expandedIdxSize) {
          doBreak = true;
          break;
        }
        loopBody(ii, expandedIdx);
      }
      if (doBreak) {
        break;
      }
    }
  }
  // Make local histogram (token counts per expert) available to all threads in the cluster.
  __cluster_barrier_arrive();
  __cluster_barrier_wait();

  //
  // Each thread now represents one expert
  //

  // Total number of tokens for this expert.
  int32_t count = 0;
  // Per-expert offset for this block.
  int32_t blockExpertOffset = 0;

  if (threadIdx.x < params.mNumExperts) {
    // Get the histogram bin from each rank for this expert.
    int32_t expertCounts[NumBlocksPerCluster];
#pragma unroll
    for (int rank = 0; rank < NumBlocksPerCluster; rank++) {
      int32_t const* remoteSmem = cg::cluster_group::map_shared_rank(smemExpertCount, rank);
      expertCounts[rank] = rank * NumWarps < params.mNumTokens ? remoteSmem[threadIdx.x] : 0;
    }

    // Compute an exclusive prefix sum of the block-local count.
#pragma unroll
    for (int rank = 0; rank < NumBlocksPerCluster; rank++) {
      if (rank == clusterBlockRank) {
        blockExpertOffset = count;
      }
      count += expertCounts[rank];
    }
  }

  // Arrive: we do not access distributed shared memory after this point.
  __cluster_barrier_arrive();

  // Compute the runtime config for projections.
  // Whether or not an expert is local is taken into account when smemExpertCount is computed,
  // so we do not need to take it into account here.

  int32_t numCta;
  if constexpr (KernelParams::isPow2) {
    numCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
  } else {
    numCta = divUpTileN<int32_t>(count, params.mTileTokensDim);
  }

  int32_t ctaOffset;
  int32_t numNonExitingCtas;
  Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);

  if (threadIdx.x < params.mNumExperts) {
    // Strided loop to share this work between blocks.
    for (int32_t cta = clusterBlockRank; cta < numCta; cta += NumBlocksPerCluster) {
      const int32_t localExpertIdx =
          (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
      params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
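      // The mn limit is the end of this CTA's token range, clamped to the expert's actual token
      // count so the last (padded) CTA does not extend past real tokens. E.g. with a tile size of
      // 8 tokens and count == 13, the expert's two CTAs get limits 8 and 13 past its base offset.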
      int32_t mnLimit1;
      int32_t mnLimit2;
      if constexpr (KernelParams::isPow2) {
        mnLimit1 = mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2);
        mnLimit2 = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + count;
      } else {
        mnLimit1 = mulTileN<int32_t>(ctaOffset + cta + 1, params.mTileTokensDim);
        mnLimit2 = mulTileN<int32_t>(ctaOffset, params.mTileTokensDim) + count;
      }
      params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2);
    }

    // get the padded offset associated with this expert
    int32_t offset;
    if constexpr (KernelParams::isPow2) {
      offset = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
    } else {
      offset = mulTileN<int32_t>(ctaOffset, params.mTileTokensDim);
    }
    // write expert offsets to shared
    smemExpertOffset[threadIdx.x] = offset + blockExpertOffset;
  }

  // write out padded count
  if (clusterBlockRank == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync()) {
    int32_t permutedIdxSize;
    if constexpr (KernelParams::isPow2) {
      permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
    } else {
      permutedIdxSize = mulTileN<int32_t>(numNonExitingCtas, params.mTileTokensDim);
    }
    params.mPtrPermutedIdxSize[0] = permutedIdxSize;
    params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
  }

  // make expert offsets available to all threads
  __syncthreads();

  // Wait: we cannot exit while other blocks may be accessing the current block's shared memory.
  // Note: I observed a perf benefit to doing this before the final loop so the compiler can
  // implement break with EXIT.
  __cluster_barrier_wait();

  // trigger the secondary kernel when using PDL
  // We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx,
  // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens
  // TODO: this is not sufficient to ensure visibility in the next kernel!
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  if constexpr (KernelParams::UsePdl) {
    cudaTriggerProgrammaticLaunchCompletion();
  }
#endif

  // each thread has the same "expanded indexes" assigned to it as above
  // at this point, we know the final offsets of experts and the offsets within
  // experts, which allows writing the final index values

#pragma unroll
  for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii) {
    auto expandedIdx = static_cast<int32_t>(clusterThreadIdx) + ii * NumThreadsPerCluster;
    if (expandedIdx >= expandedIdxSize) {
      break;
    }
    auto expertIdx = expertIndexes[ii];
    // check whether this expert is local to our GPU at all
    auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
    auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent &&
                         (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
    auto tokenIdx = expandedIdx / params.mTopK;
    auto permutedIdx =
        isLocalExpert ? int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1};
    if (params.mPtrExpandedIdxToPermutedIdx != nullptr) {
      params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
    }
    if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert) {
      params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////
// Two-step approach (if the number of tokens exceeds the limits of what the cluster / cooperative
// launch variants can handle): in order to minimize the amount of data exchanged through global
// memory, we compute the local histograms in smem twice: the first kernel produces the total
// number of tokens per expert, and the second kernel uses smem and L2 atomics to derive the
// corresponding element and tile offsets.
//
// Note: the histogram calculation could also be fused with routingMainKernel, but this might be
// inefficient if we have one CTA per token doing a single global atomic.
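//
// Illustrative launch order (a sketch only, not the actual host-side dispatch; grid shapes and
// stream are placeholders), once the top-k results (mPtrTopKPacked / mPtrTopKIds) are available:
//   routingInitExpertCounts<<<gridInit, KernelParams::MaxNumExperts, 0, stream>>>(params);
//   routingIndicesHistogramKernel<<<gridHist, KernelParams::MaxNumExperts, 0, stream>>>(params);
//   routingIndicesOffsetsKernel<<<gridOffsets, KernelParams::MaxNumExperts, 0, stream>>>(params);
// The init kernel zeroes both halves of mPtrExpertCounts, the histogram kernel accumulates
// per-expert token counts into the first half, and the offsets kernel turns those counts into
// tile/element offsets and the permuted index mappings.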
template <typename KernelParams>
__global__ void __launch_bounds__(KernelParams::MaxNumExperts)
    routingIndicesHistogramKernel(KernelParams params) {
  using OutputT = typename KernelParams::OutputT;

  // number of experts is bounded by number of threads
  __shared__ int32_t __attribute((aligned(128))) smemExpertCount[KernelParams::MaxNumExperts];

  // For unrolling.
  uint32_t constexpr NumEltsPerThread = 8;

  // Pre-fill the counts with 0
  if (threadIdx.x < params.mNumExperts) {
    smemExpertCount[threadIdx.x] = 0;
  }
  __syncthreads();

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Wait on primary grid and trigger secondary kernel.
  if constexpr (KernelParams::UsePdl) {
    cudaGridDependencySynchronize();
    cudaTriggerProgrammaticLaunchCompletion();
  }
#endif  // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))

  uint32_t const expandedIdxSize = params.mNumTokens * params.mTopK;
  uint32_t const localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;

  uint32_t const gridBlockOffset = blockIdx.x * KernelParams::MaxNumExperts;
  uint32_t const gridStride = gridDim.x * KernelParams::MaxNumExperts;

  // Define a lambda to avoid code duplication in branches.
  auto loopBody = [&](int expandedIdx) {
    PackedScoreIdx<OutputT> scoreIdx;
    int idx;
    if (params.mPtrTopKIds != nullptr) {
      idx = params.mPtrTopKIds[expandedIdx];
    } else {
      scoreIdx = params.mPtrTopKPacked[expandedIdx];
      idx = scoreIdx.idx;
      // Weights only need to be materialized when a separate weights buffer is requested
      // (when mPtrTopKIds is provided, the weights are assumed to be in place already).
      if (params.mPtrTopKWeights != nullptr) {
        params.mPtrTopKWeights[expandedIdx] = static_cast<OutputT>(scoreIdx.score);
      }
    }
    // check whether this expert is local to our GPU at all and ignore if not
    auto localExpertIdx = idx - params.mLocalExpertsStartIdx;
    auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent &&
                         (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
    if (isLocalExpert) {
      atomicAdd(&smemExpertCount[idx], 1);
    }
  };

  // Grid-stride loop.
  for (uint32_t expandedIdx0 = gridBlockOffset * NumEltsPerThread; expandedIdx0 < expandedIdxSize;
       expandedIdx0 += gridStride * NumEltsPerThread) {
    // Fast path if bound checks aren't necessary
    if (expandedIdx0 + NumEltsPerThread * KernelParams::MaxNumExperts <= expandedIdxSize) {
#pragma unroll
      for (uint32_t ii = 0; ii < NumEltsPerThread; ii++) {
        uint32_t expandedIdx = expandedIdx0 + ii * KernelParams::MaxNumExperts + threadIdx.x;
        loopBody(expandedIdx);
      }
    } else {
      for (uint32_t expandedIdx = expandedIdx0 + threadIdx.x; expandedIdx < expandedIdxSize;
           expandedIdx += KernelParams::MaxNumExperts) {
        loopBody(expandedIdx);
      }
    }
  }
  __syncthreads();

  //
  // Each thread now represents one expert
  //

  // Reduce histograms with atomics.
  if (threadIdx.x < params.mNumExperts) {
    int32_t const localExpertCount = smemExpertCount[threadIdx.x];
    atomicAdd(&params.mPtrExpertCounts[threadIdx.x], localExpertCount);
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename KernelParams>
__global__ void __launch_bounds__(KernelParams::MaxNumExperts)
    routingIndicesOffsetsKernel(KernelParams params) {
  using OutputT = typename KernelParams::OutputT;

  // number of experts is bounded by number of threads
  __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[KernelParams::MaxNumExperts];
  __shared__ int32_t __attribute((aligned(128))) smemExpertCount[KernelParams::MaxNumExperts];
  __shared__ int32_t __attribute((aligned(128))) smemExpertTileOffset[KernelParams::MaxNumExperts];
  // needed for the exclusive sum of token offsets
  using Scan = cub::BlockScan<int32_t, KernelParams::MaxNumExperts, cub::BLOCK_SCAN_WARP_SCANS>;
  __shared__ typename Scan::TempStorage tempStorage;
  static constexpr int MaxExpandedIdxPerThread = NumEltsPerOffsetTilePerThread;
  static constexpr int MaxExpandedIdxPerBlock =
      KernelParams::MaxNumExperts * MaxExpandedIdxPerThread;

  int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);

  uint32_t const expandedIdxSize = params.mNumTokens * params.mTopK;
  uint32_t const numTiles =
      (expandedIdxSize + MaxExpandedIdxPerBlock - 1) / (MaxExpandedIdxPerBlock);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Wait on primary grid.
  if constexpr (KernelParams::UsePdl) {
    cudaGridDependencySynchronize();
  }
#endif  // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))

  // The expert offsets are common to all tiles of all blocks.
  // Load the histogram, scan it and write offsets to shared memory.
  // Note: the scan is redundant in all CTAs. Would it make sense to use an intermediate kernel for
  // the scan, with PDL?

  //
  // Each thread represents one expert.
  //

  // Get total count for this expert.
  int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0;

  // Compute the runtime config for projections.
  // Whether or not an expert is local is taken into account when the histogram is computed,
  // so we do not need to take it into account here.
  int32_t numCta;
  if constexpr (KernelParams::isPow2) {
    numCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
  } else {
    numCta = divUpTileN<int32_t>(count, params.mTileTokensDim);
  }
  int32_t ctaOffset;
  int32_t numNonExitingCtas;
  Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);

  if (threadIdx.x < params.mNumExperts) {
    // Get the padded offset associated with this expert
    int32_t offset;
    if constexpr (KernelParams::isPow2) {
      offset = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
    } else {
      offset = mulTileN<int32_t>(ctaOffset, params.mTileTokensDim);
    }

    // Write expert offsets to shared
    smemExpertOffset[threadIdx.x] = offset;
  }

  // Sync to make expert offsets available to all threads.
  __syncthreads();

  // The first block writes out padded count
  if (blockIdx.x == 0 && warpIdx == KernelParams::MaxNumExperts / WarpSize - 1 &&
      cute::elect_one_sync()) {
    int32_t permutedIdxSize;
    if constexpr (KernelParams::isPow2) {
      permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
    } else {
      permutedIdxSize = mulTileN<int32_t>(numNonExitingCtas, params.mTileTokensDim);
    }
    params.mPtrPermutedIdxSize[0] = permutedIdxSize;
    params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
  }

  if (threadIdx.x < params.mNumExperts) {
    // Strided loop to share this work between blocks.
    for (int32_t cta = blockIdx.x; cta < numCta; cta += gridDim.x) {
      const int32_t localExpertIdx =
          (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
      params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
      int32_t mnLimit1;
      int32_t mnLimit2;
      if constexpr (KernelParams::isPow2) {
        mnLimit1 = mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2);
        mnLimit2 = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + count;
      } else {
        mnLimit1 = mulTileN<int32_t>(ctaOffset + cta + 1, params.mTileTokensDim);
        mnLimit2 = mulTileN<int32_t>(ctaOffset, params.mTileTokensDim) + count;
      }
      params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2);
    }
  }

  //
  // Now loop on indices and compute offsets.
  //

  // Grid-stride loop on 1D "tiles" of input indices.
  for (uint32_t tileIdx = blockIdx.x; tileIdx < numTiles; tileIdx += gridDim.x) {
    if (tileIdx > 0) {
      // Sync for safe reuse of smem buffers.
      __syncthreads();
    }

    // Pre-fill the counts with 0
    if (threadIdx.x < params.mNumExperts) {
      smemExpertCount[threadIdx.x] = 0;
    }
    __syncthreads();

    // each thread has some number of "expanded indexes" assigned to it
    // for each of these, we keep the associated expert and offset within expert in registers
    int32_t expertIndexes[MaxExpandedIdxPerThread];
    int32_t expertOffsets[MaxExpandedIdxPerThread];
    auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;

    // Define a lambda to avoid code duplication in branches.
    auto loopBody = [&](int ii, int expandedIdx) {
      expertIndexes[ii] = params.mPtrTopKIds ? params.mPtrTopKIds[expandedIdx]
                                             : params.mPtrTopKPacked[expandedIdx].idx;
      // check whether this expert is local to our GPU at all and ignore if not
      auto localExpertIdx = expertIndexes[ii] - params.mLocalExpertsStartIdx;
      auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent &&
                           (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
      expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + expertIndexes[ii], 1) : 0;
    };

    // For all tiles but the last, all indices are in bounds.
    if (tileIdx < numTiles - 1) {
#pragma unroll
      for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1) {
        auto expandedIdx =
            tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x;
        loopBody(ii, expandedIdx);
      }
    } else {
      // For the last tile, we need to exit the loop when out of bounds.
      // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a
      // time, and branch between a fast path without bound checks and a slow path with bound checks.
      int constexpr IterStride = 4;
      static_assert(MaxExpandedIdxPerThread % IterStride == 0);

#pragma unroll
      for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride) {
        // Whether it's safe to do multiple iterations without bound checks.
        bool const takeFastPath =
            tileIdx * MaxExpandedIdxPerBlock + (ii0 + IterStride) * KernelParams::MaxNumExperts <=
            expandedIdxSize;
        if (takeFastPath) {
#pragma unroll
          for (int32_t jj = 0; jj < IterStride; jj++) {
            int const ii = ii0 + jj;
            auto expandedIdx =
                tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x;
            loopBody(ii, expandedIdx);
          }
        } else {
          bool doBreak = false;
#pragma unroll
          for (int32_t jj = 0; jj < IterStride; jj++) {
            int const ii = ii0 + jj;
            auto expandedIdx =
                tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x;
            if (expandedIdx >= expandedIdxSize) {
              doBreak = true;
              break;
            }
            loopBody(ii, expandedIdx);
          }
          if (doBreak) {
            break;
          }
        }
      }
    }

    // Make local histogram (token counts per expert) available to all threads in the block.
    __syncthreads();

    //
    // Each thread now represents one expert
    //

    if (threadIdx.x < params.mNumExperts) {
      // Add the local bin count to the common bin count and get a per-CTA offset. We use the second
      // half of the histogram buffer for this histogram, because the first half already holds the
      // reduced histogram from the previous kernel.
      int32_t const localExpertCount = smemExpertCount[threadIdx.x];
      int32_t const tileExpertOffset =
          atomicAdd(&params.mPtrExpertCounts[params.mNumExperts + threadIdx.x], localExpertCount);

      // Make per-expert tile offsets available to all threads in the block.
      smemExpertTileOffset[threadIdx.x] = tileExpertOffset + smemExpertOffset[threadIdx.x];
    }
    __syncthreads();

    // Add tile offset and element offset and write to global memory.
    auto storeLoopBody = [&](int ii, int expandedIdx) {
      int32_t expertIdx = expertIndexes[ii];
      // check whether this expert is local to our GPU at all
      auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
      auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent &&
                           (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
      auto tokenIdx = expandedIdx / params.mTopK;
      auto permutedIdx =
          isLocalExpert ? (expertOffsets[ii] + smemExpertTileOffset[expertIdx]) : int32_t{-1};
      if (params.mPtrExpandedIdxToPermutedIdx != nullptr) {
        params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
      }
      if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert) {
        params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
      }
    };
    // Bound checks only in last tile.
    if (tileIdx < numTiles - 1) {
#pragma unroll
      for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1) {
        auto expandedIdx =
            tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x;
        storeLoopBody(ii, expandedIdx);
      }
    } else {
#pragma unroll
      for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1) {
        auto expandedIdx =
            tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x;
        if (expandedIdx >= expandedIdxSize) {
          break;
        }
        storeLoopBody(ii, expandedIdx);
      }
    }
  }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Trigger secondary kernel.
  // Note: this does not guarantee the visibility of prior writes unless the consumer executes a
  // dependency sync.
  if constexpr (KernelParams::UsePdl) {
    cudaTriggerProgrammaticLaunchCompletion();
  }
#endif  // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename KernelParams>
__global__ void __launch_bounds__(KernelParams::MaxNumExperts)
    routingInitExpertCounts(KernelParams params) {
  // Initialize mPtrExpertCounts (both halves of the histogram buffer) to zero.
  int32_t expertCountsNum = 2 * params.mNumExperts;
  int32_t globalThreadIdx = blockIdx.x * KernelParams::MaxNumExperts + threadIdx.x;
  int32_t globalThreadStride = gridDim.x * KernelParams::MaxNumExperts;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Wait on primary grid.
  if constexpr (KernelParams::UsePdl) {
    cudaGridDependencySynchronize();
  }
#endif  // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))

  initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Trigger the secondary kernel.
  if constexpr (KernelParams::UsePdl) {
    cudaTriggerProgrammaticLaunchCompletion();
  }
#endif  // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
}
}  // namespace routing
}  // namespace moe::dev
